import jsonlines
import torch
import sys


def organize_ppls(data):
    """Group flat records by (question, options) and collect ``ppl`` values.

    Each record in ``data`` is a mapping read with the keys ``"Q"``, ``"O"``,
    ``"A"``, ``"K"`` and ``"ppl"``, plus an optional ``"E"``. Records whose
    ``"K"`` is the empty string are skipped. All remaining records that share
    the same ``"Q"`` and ``"O"`` are merged into a single entry whose ``"K"``
    field maps every distinct ``d["K"]`` to the list of its ``ppl`` values, in
    input order.

    The grouping key is the pair ``(Q, tuple(O))`` rather than a string
    concatenation, so two different records such as ``Q="ab", O=["c"]`` and
    ``Q="a", O=["bc"]`` are no longer merged by accident.

    Args:
        data: Iterable of record mappings. ``"O"`` must be an iterable of
            hashable items (presumably a list of option strings -- confirm
            against the producer of the input file).

    Returns:
        A view over the grouped entries (``dict.values()``), in order of first
        appearance. Each entry has the keys ``"Q"``, ``"O"``, ``"A"``, ``"E"``
        (``""`` when the first record of the group has none) and ``"K"``.
    """
    results = {}
    for d in data:
        if d["K"] == "":
            continue
        # Structured key: unambiguous, unlike Q + "".join(O).
        group_key = (d["Q"], tuple(d["O"]))
        entry = results.get(group_key)
        if entry is None:
            # The first record of a group supplies the shared metadata.
            entry = results[group_key] = {
                "Q": d["Q"],
                "O": d["O"],
                "A": d["A"],
                "E": d.get("E", ""),
                "K": {},
            }
        entry["K"].setdefault(d["K"], []).append(d["ppl"])
    return results.values()


# Command-line arguments: <model> <dataset> <rate> <round_index>.
if len(sys.argv) < 5:
    sys.exit(f"usage: {sys.argv[0]} <model> <dataset> <rate> <round_index>")

model, dataset, rate, round_index = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]

# `rate` arrives as a string; float() raises ValueError when it is not numeric.
try:
    rate = float(rate)
except ValueError:
    raise TypeError(f"Expect type of float, got {rate}") from None

# Read every input record, closing the reader as soon as the data is loaded.
with jsonlines.open(f"./{model}/round{round_index}/{dataset}/{dataset}_ppl.jsonl", "r") as reader:
    data = list(reader)
data = organize_ppls(data)

# The output is opened only after the input was read and grouped successfully;
# the context manager closes it even if a write or the tensor math fails.
with jsonlines.open(f"./{model}/round{round_index}/{dataset}/{dataset}_probs.jsonl", "w") as fo:
    for d in data:
        for key in d['K']:
            dist = torch.FloatTensor(d['K'][key])
            # Scale by `rate`, softmax over the values collected for this key,
            # then store the complement (1 - softmax) as a plain list.
            d['K'][key] = (1 - torch.softmax(dist * rate, dim=0)).tolist()
        fo.write(d)








